Diffusion-QL suffers from two critical limitations: (1) it is computationally inefficient to forward and backward through the whole Markov chain during training, and (2) it is incompatible with maximum likelihood-based RL algorithms (e.g., policy gradient methods), since the likelihood of diffusion models is intractable. EDP instead approximately constructs actions from corrupted ones at training time to avoid running the sampling chain.
Efficient Diffusion Policy
We present a novel algorithm termed Reinforcement-Guided Diffusion Policy Learning (RGDPL).
Diffusion Policy
We use the reverse process of a conditional diffusion model as a parametric policy:
$$\pi_\theta(a \mid s) = p_\theta(a^{0:K} \mid s) = p(a^K) \prod_{k=1}^{K} p_\theta(a^{k-1} \mid a^k, s),$$
where $a^K \sim \mathcal{N}(0, I)$.
Given a dataset, we can easily and efficiently train a diffusion policy in a behavior-cloning manner as we only need to forward and backward through the network once each iteration.
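For concreteness, here is a minimal PyTorch-style sketch of one such behavior-cloning step, assuming a hypothetical noise-prediction network `eps_model(a_k, k, s)` and a precomputed cumulative noise schedule `alpha_bar` (both names are ours, not the paper's code):

```python
import torch

def diffusion_bc_loss(eps_model, alpha_bar, s, a0):
    """One behavior-cloning step of the diffusion policy.

    eps_model : network predicting the noise eps_theta(a^k, k; s)
    alpha_bar : tensor of shape (K,) holding the cumulative products alpha_bar_k
    s, a0     : batch of states and dataset actions
    """
    B, K = a0.shape[0], alpha_bar.shape[0]
    # Sample a diffusion step k and Gaussian noise for each action in the batch.
    k = torch.randint(0, K, (B,), device=a0.device)
    eps = torch.randn_like(a0)
    ab = alpha_bar[k].unsqueeze(-1)
    # Corrupt the dataset action: a^k = sqrt(ab) * a^0 + sqrt(1 - ab) * eps.
    a_k = ab.sqrt() * a0 + (1.0 - ab).sqrt() * eps
    # A single forward and backward pass: predict the injected noise and regress onto it.
    return ((eps_model(a_k, k, s) - eps) ** 2).mean()
```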
Reinforcement-Guided Diffusion Policy Learning
The key question is how to efficiently use $Q_\phi$ to guide the diffusion policy training procedure.
We now show that this can be achieved without sampling actions from diffusion policies.
Using the reparameterization trick, we are able to connect $a^k$, $a^0$, and $\epsilon$ by:
$$a^k = \sqrt{\bar{\alpha}_k}\, a^0 + \sqrt{1 - \bar{\alpha}_k}\, \epsilon, \qquad \epsilon \sim \mathcal{N}(0, I).$$
Recall that our diffusion policy is parameterized to predict $\epsilon$ with $\epsilon_\theta(a^k, k; s)$. By replacing $\epsilon$ with $\epsilon_\theta(a^k, k; s)$, we obtain the approximated action:
$$\hat{a}^0 = \frac{1}{\sqrt{\bar{\alpha}_k}}\, a^k - \frac{\sqrt{1 - \bar{\alpha}_k}}{\sqrt{\bar{\alpha}_k}}\, \epsilon_\theta(a^k, k; s).$$
Accordingly, the policy-improvement objective for diffusion policies is modified as follows:
$$\mathcal{L}_\pi(\theta) = -\mathbb{E}_{s \sim \mathcal{D},\, \hat{a}^0}\big[ Q_\phi(s, \hat{a}^0) \big].$$
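A minimal sketch of this guidance step, reusing the hypothetical `eps_model` and `alpha_bar` from the previous sketch together with a critic `q_net(s, a)` (again, illustrative names only):

```python
import torch

def approx_action(eps_model, alpha_bar, s, a_k, k):
    """Reconstruct hat{a}^0 from a corrupted action a^k with one network call."""
    ab = alpha_bar[k].unsqueeze(-1)
    # hat{a}^0 = a^k / sqrt(ab) - sqrt(1 - ab) / sqrt(ab) * eps_theta(a^k, k; s)
    return (a_k - (1.0 - ab).sqrt() * eps_model(a_k, k, s)) / ab.sqrt()

def policy_improvement_loss(eps_model, q_net, alpha_bar, s, a0):
    """L_pi(theta) = -E[ Q_phi(s, hat{a}^0) ], without running the sampling chain."""
    B, K = a0.shape[0], alpha_bar.shape[0]
    k = torch.randint(0, K, (B,), device=a0.device)
    eps = torch.randn_like(a0)
    ab = alpha_bar[k].unsqueeze(-1)
    a_k = ab.sqrt() * a0 + (1.0 - ab).sqrt() * eps   # corrupt the dataset action
    a0_hat = approx_action(eps_model, alpha_bar, s, a_k, k)
    return -q_net(s, a0_hat).mean()
```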
To improve the efficiency of policy evaluation, we propose to replace the DDPM sampling with DPM-Solver [20], which is an ODE-based sampler.
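The actual DPM-Solver update rule is more involved; as a rough illustration of few-step sampling for policy evaluation, the sketch below uses a deterministic DDIM-style update instead, which should be read as a stand-in rather than the solver used in the paper:

```python
import torch

@torch.no_grad()
def sample_actions(eps_model, alpha_bar, s, act_dim, steps=5):
    """Few-step deterministic sampling (DDIM-style stand-in, not DPM-Solver)."""
    B, K = s.shape[0], alpha_bar.shape[0]
    # An evenly spaced sub-sequence of diffusion steps, from K-1 down to 0.
    ks = torch.linspace(K - 1, 0, steps, device=s.device).long()
    a = torch.randn(B, act_dim, device=s.device)      # a^K ~ N(0, I)
    for i, k in enumerate(ks):
        ab_k = alpha_bar[k]
        eps = eps_model(a, k.repeat(B), s)
        a0_hat = (a - (1.0 - ab_k).sqrt() * eps) / ab_k.sqrt()
        if i + 1 < len(ks):
            ab_prev = alpha_bar[ks[i + 1]]
            # Deterministic jump to the next (lower) noise level.
            a = ab_prev.sqrt() * a0_hat + (1.0 - ab_prev).sqrt() * eps
        else:
            a = a0_hat
    return a
```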
Generalization to Various RL Algorithms
Direct policy optimization. It maximizes Q values and directly backpropagates the gradients from the Q network to the policy network:
$$\nabla_\theta \mathcal{L}_\pi(\theta) = -\frac{\partial Q_\phi(s, a)}{\partial a} \frac{\partial a}{\partial \theta}.$$
This is only applicable to cases where $\frac{\partial a}{\partial \theta}$ is tractable, e.g., when a deterministic policy $a = \pi_\theta(s)$ is used or when the sampling process can be reparameterized.
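For reference, a minimal sketch of this scheme for a generic deterministic policy `pi_net(s)` and critic `q_net(s, a)` (standard actor-critic code, not specific to diffusion policies):

```python
import torch

def direct_policy_loss(pi_net, q_net, s):
    """Maximize Q by backpropagating dQ/da through the action into the policy."""
    a = pi_net(s)               # a = pi_theta(s), differentiable w.r.t. theta
    return -q_net(s, a).mean()  # gradient flows as dQ/da * da/dtheta

# Usage sketch (networks, optimizer, and the state batch are assumed to exist):
# loss = direct_policy_loss(pi_net, q_net, states)
# loss.backward(); optimizer.step()
```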
Likelihood-based policy optimization. It tries to distill the knowledge from the Q network into the policy network indirectly by performing weighted regression or weighted maximum likelihood:
$$\max_\theta\; \mathbb{E}_{(s, a) \sim \mathcal{D}}\big[ f(Q_\phi(s, a)) \log \pi_\theta(a \mid s) \big],$$
where $f(Q_\phi(s, a))$ is a monotonically increasing function that assigns a weight to each state-action pair in the dataset. This objective requires the log-likelihood of the policy to be tractable and differentiable.
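For a policy class with a tractable density the objective is straightforward to implement, which is exactly what diffusion policies lack; the sketch below shows the tractable case, assuming a hypothetical `pi_net(s)` that returns a `torch.distributions` object and a weighting function `f`:

```python
import torch

def weighted_regression_loss(pi_net, q_net, f, s, a):
    """Weighted maximum likelihood: maximize E[ f(Q(s, a)) * log pi(a|s) ]."""
    with torch.no_grad():
        w = f(q_net(s, a))                    # per-sample weights, no grad into Q
    log_prob = pi_net(s).log_prob(a).sum(-1)  # tractable log-likelihood
    return -(w * log_prob).mean()
```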
In this paper, we consider two ways to make this objective compatible with diffusion policies. First, instead of computing the likelihood, we turn to the lower bound for $\log \pi_\theta(a \mid s)$ introduced in DDPM. Discarding the constant terms that do not depend on $\theta$ yields an objective that weights the standard denoising loss $\| \epsilon - \epsilon_\theta(a^k, k; s) \|^2$ by $f(Q_\phi(s, a))$.
Second, instead of directly optimizing $\log \pi_\theta(a \mid s)$, we propose to replace it with an approximated policy $\hat{\pi}_\theta(a \mid s) \triangleq \mathcal{N}(\hat{a}^0, I)$. Then, we get the following objective (to be minimized):
$$\mathbb{E}_{k, \epsilon, (a, s)}\big[ f(Q_\phi(s, a)) \, \| a - \hat{a}^0 \|^2 \big].$$
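A sketch of this second realization, again using the hypothetical names from the earlier sketches; the commented line indicates how the first (lower-bound) realization would differ:

```python
import torch

def weighted_diffusion_loss(eps_model, q_net, f, alpha_bar, s, a0):
    """Second realization: weight || a - hat{a}^0 ||^2 by f(Q_phi(s, a))."""
    B, K = a0.shape[0], alpha_bar.shape[0]
    k = torch.randint(0, K, (B,), device=a0.device)
    eps = torch.randn_like(a0)
    ab = alpha_bar[k].unsqueeze(-1)
    a_k = ab.sqrt() * a0 + (1.0 - ab).sqrt() * eps        # corrupt the dataset action
    eps_pred = eps_model(a_k, k, s)
    a0_hat = (a_k - (1.0 - ab).sqrt() * eps_pred) / ab.sqrt()
    with torch.no_grad():
        w = f(q_net(s, a0))                               # Q-based weights, no grad into Q
    # First realization (lower bound) would instead weight the noise-prediction error:
    #   return (w * ((eps_pred - eps) ** 2).sum(-1)).mean()
    return (w * ((a0 - a0_hat) ** 2).sum(-1)).mean()
```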
Empirically, we find these two choices perform similarly, but the latter is easier to implement, so we report results mainly based on the second realization. In our experiments, we consider two offline RL algorithms under this category, i.e., CRR and IQL. They use two weighting schemes: $f_{\mathrm{CRR}} = \exp\big[ \big( Q_\phi(s, a) - \mathbb{E}_{a' \sim \hat{\pi}(\cdot \mid s)}\, Q(s, a') \big) / \tau_{\mathrm{CRR}} \big]$ and $f_{\mathrm{IQL}} = \exp\big[ \big( Q_\phi(s, a) - V_\psi(s) \big) / \tau_{\mathrm{IQL}} \big]$, where $\tau$ refers to the temperature parameter and $V_\psi(s)$ is an additional value network parameterized by $\psi$.
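A sketch of these two weighting schemes, where `v_net` is the hypothetical IQL value network and `a_pi_samples` holds actions drawn from the current (approximate) policy for the CRR baseline; both names are ours:

```python
import torch

def crr_weight(q_net, s, a, a_pi_samples, tau):
    """f_CRR = exp[(Q(s,a) - E_{a'~pi}[Q(s,a')]) / tau]: advantage over policy samples."""
    with torch.no_grad():
        # a_pi_samples: tensor of shape (num_samples, B, act_dim) from the current policy.
        baseline = torch.stack([q_net(s, a_p) for a_p in a_pi_samples]).mean(0)
        return torch.exp((q_net(s, a) - baseline) / tau)

def iql_weight(q_net, v_net, s, a, tau):
    """f_IQL = exp[(Q(s,a) - V(s)) / tau]: advantage against a learned value network."""
    with torch.no_grad():
        return torch.exp((q_net(s, a) - v_net(s)) / tau)
```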